In [1]:
# https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/train-rt-detr-on-custom-dataset-with-transformers.ipynb?ref=blog.roboflow.com#scrollTo=h8KHJfSZFh7L
# https://blog.roboflow.com/train-rt-detr-custom-dataset-transformers/
In [2]:
!nvidia-smi
Sun Sep  8 16:33:02 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 552.12                 Driver Version: 552.12         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3080 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   54C    P8             23W /  130W |     952MiB /   8192MiB |     22%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      3048    C+G   ...ft Office\root\Office16\WINWORD.EXE      N/A      |
|    0   N/A  N/A      4360    C+G   C:\Windows\explorer.exe                     N/A      |
|    0   N/A  N/A      5548    C+G   ...crosoft\Edge\Application\msedge.exe      N/A      |
|    0   N/A  N/A      9828    C+G   ...5n1h2txyewy\ShellExperienceHost.exe      N/A      |
|    0   N/A  N/A      9840    C+G   ...8bbwe\SnippingTool\SnippingTool.exe      N/A      |
|    0   N/A  N/A      9868    C+G   ...2txyewy\StartMenuExperienceHost.exe      N/A      |
|    0   N/A  N/A     10552    C+G   ...ekyb3d8bbwe\PhoneExperienceHost.exe      N/A      |
|    0   N/A  N/A     14224    C+G   ...nt.CBS_cw5n1h2txyewy\SearchHost.exe      N/A      |
|    0   N/A  N/A     14368    C+G   ...oogle\Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A     15036    C+G   ...590_x64__8wekyb3d8bbwe\ms-teams.exe      N/A      |
|    0   N/A  N/A     17884    C+G   ...CBS_cw5n1h2txyewy\TextInputHost.exe      N/A      |
|    0   N/A  N/A     19196    C+G   ...\Docker\frontend\Docker Desktop.exe      N/A      |
|    0   N/A  N/A     20192    C+G   ...__8wekyb3d8bbwe\Notepad\Notepad.exe      N/A      |
|    0   N/A  N/A     22676    C+G   ...590_x64__8wekyb3d8bbwe\ms-teams.exe      N/A      |
|    0   N/A  N/A     25416    C+G   ...on\128.0.2739.63\msedgewebview2.exe      N/A      |
|    0   N/A  N/A     26480    C+G   ...cal\Microsoft\OneDrive\OneDrive.exe      N/A      |
|    0   N/A  N/A     27752    C+G   ...siveControlPanel\SystemSettings.exe      N/A      |
|    0   N/A  N/A     27904    C+G   ...t.LockApp_cw5n1h2txyewy\LockApp.exe      N/A      |
|    0   N/A  N/A     27996    C+G   ...590_x64__8wekyb3d8bbwe\ms-teams.exe      N/A      |
|    0   N/A  N/A     30036    C+G   ...__8wekyb3d8bbwe\WindowsTerminal.exe      N/A      |
|    0   N/A  N/A     30144    C+G   ...soft Office\root\Office16\EXCEL.EXE      N/A      |
|    0   N/A  N/A     31472    C+G   ...42.0_x64__8wekyb3d8bbwe\GameBar.exe      N/A      |
|    0   N/A  N/A     32728    C+G   ...cal\Microsoft\OneDrive\OneDrive.exe      N/A      |
|    0   N/A  N/A     40580    C+G   ...5.8.2.0_x64__htrsf667h5kn2\AWCC.exe      N/A      |
+-----------------------------------------------------------------------------------------+
In [4]:
!python -V
Python 3.11.7
In [3]:
# !pip install -q git+https://github.com/huggingface/transformers.git
# !pip install -q git+https://github.com/roboflow/supervision.git
# !pip install -q accelerate
# !pip install -q roboflow
# !pip install -q torchmetrics
# !pip install -q "albumentations>=1.4.5"
In [4]:
import torch #
import requests #

import numpy as np #
import supervision as sv #
import albumentations as A #

from PIL import Image #
from pprint import pprint #
from roboflow import Roboflow
from dataclasses import dataclass, replace #
# from google.colab import userdata
from torch.utils.data import Dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForObjectDetection,
    TrainingArguments,
    Trainer
)
from torchmetrics.detection.mean_ap import MeanAveragePrecision #
In [5]:
# @title Load model

# Using the classic DETR checkpoint rather than the real-time (RT-DETR) one.
# Alternative: CHECKPOINT = "PekingU/rtdetr_r50vd_coco_o365"
CHECKPOINT = "facebook/detr-resnet-50"

# Prefer the GPU when torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
model = AutoModelForObjectDetection.from_pretrained(CHECKPOINT).to(DEVICE)
Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
In [6]:
# Sanity-check the pretrained model on a sample image before fine-tuning.
URL = "https://media.roboflow.com/notebooks/examples/dog.jpeg"

response = requests.get(URL, stream=True)
image = Image.open(response.raw)
inputs = processor(image, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    outputs = model(**inputs)

width, height = image.size
results = processor.post_process_object_detection(
    outputs,
    target_sizes=[(height, width)],  # (h, w) order expected by the processor
    threshold=0.3,
)
In [7]:
# @title Display result with NMS

detections = sv.Detections.from_transformers(results[0])
labels = [model.config.id2label[cid] for cid in detections.class_id]

annotated_image = image.copy()
annotated_image = sv.BoxAnnotator().annotate(annotated_image, detections)
annotated_image = sv.LabelAnnotator().annotate(
    annotated_image, detections, labels=labels)
annotated_image.thumbnail((600, 600))
annotated_image
Out[7]:
No description has been provided for this image
In [8]:
# !pip install roboflow

# from roboflow import Roboflow
rf = Roboflow(api_key="ZlLYb0hY8PrrxeT2vL0E")
project = rf.workspace("assuralabs").project("av-water-damage")
version = project.version(1)
dataset = version.download("coco")
loading Roboflow workspace...
loading Roboflow project...
In [9]:
def _load_split(split):
    """Load one COCO-format split exported by Roboflow into a DetectionDataset."""
    return sv.DetectionDataset.from_coco(
        images_directory_path=f"{dataset.location}/{split}",
        annotations_path=f"{dataset.location}/{split}/_annotations.coco.json",
    )

ds_train = _load_split("train")
ds_valid = _load_split("valid")
ds_test = _load_split("test")

print(f"Number of training images: {len(ds_train)}")
print(f"Number of validation images: {len(ds_valid)}")
print(f"Number of test images: {len(ds_test)}")
Number of training images: 747
Number of validation images: 215
Number of test images: 183
In [10]:
# @title Display dataset sample

GRID_SIZE = 5

def annotate(image, annotations, classes):
    """Draw boxes and class-name labels on a copy of `image`; return the copy."""
    labels = [classes[class_id] for class_id in annotations.class_id]

    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator(text_scale=1, text_thickness=2)

    out = image.copy()
    out = box_annotator.annotate(out, annotations)
    out = label_annotator.annotate(out, annotations, labels=labels)
    return out

# Annotate the first GRID_SIZE^2 training images and tile them into one grid.
annotated_images = []
for i in range(GRID_SIZE * GRID_SIZE):
    _, image, annotations = ds_train[i]
    annotated_images.append(annotate(image, annotations, ds_train.classes))

grid = sv.create_tiles(
    annotated_images,
    grid_size=(GRID_SIZE, GRID_SIZE),
    single_tile_size=(400, 400),
    tile_padding_color=sv.Color.WHITE,
    tile_margin_color=sv.Color.WHITE
)
sv.plot_image(grid, size=(10, 10))
No description has been provided for this image
In [11]:
# Re-create the processor with a fixed square input resolution for training.
IMAGE_SIZE = 640

processor = AutoImageProcessor.from_pretrained(
    CHECKPOINT,
    do_resize=True,
    size={"height": IMAGE_SIZE, "width": IMAGE_SIZE},
)
In [12]:
def _bbox_params(min_area):
    """Shared bbox handling: pascal_voc boxes, clipped, with an area floor."""
    return A.BboxParams(
        format="pascal_voc",
        label_fields=["category"],
        clip=True,
        min_area=min_area,
    )

# Geometric + photometric augmentations, applied to training images only.
train_augmentation_and_transform = A.Compose(
    [
        A.Perspective(p=0.1),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.HueSaturationValue(p=0.1),
    ],
    bbox_params=_bbox_params(25),
)

# Validation: identity transform, but still clip/filter degenerate boxes.
valid_transform = A.Compose(
    [A.NoOp()],
    bbox_params=_bbox_params(1),
)
In [13]:
# @title Visualize some augmented images

IMAGE_COUNT = 5

for i in range(IMAGE_COUNT):
    _, image, annotations = ds_train[i]

    augmented = train_augmentation_and_transform(
        image=image,
        bboxes=annotations.xyxy,
        category=annotations.class_id,
    )
    augmented_annotations = replace(
        annotations,
        xyxy=np.array(augmented["bboxes"]),
        class_id=np.array(augmented["category"]),
    )

    # Show the original and its augmented version side by side.
    pair = [
        annotate(image, annotations, ds_train.classes),
        annotate(augmented["image"], augmented_annotations, ds_train.classes),
    ]
    grid = sv.create_tiles(
        pair,
        titles=['original', 'augmented'],
        titles_scale=0.5,
        single_tile_size=(400, 400),
        tile_padding_color=sv.Color.WHITE,
        tile_margin_color=sv.Color.WHITE
    )
    sv.plot_image(grid, size=(6, 6))
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [14]:
class PyTorchDetectionDataset(Dataset):
    """Adapts an sv.DetectionDataset to the per-sample dict format
    (pixel_values / pixel_mask / labels) expected by HF detection models,
    optionally applying an albumentations transform on the fly."""

    def __init__(self, dataset: sv.DetectionDataset, processor, transform: A.Compose = None):
        self.dataset = dataset
        self.processor = processor
        self.transform = transform

    @staticmethod
    def annotations_as_coco(image_id, categories, boxes):
        """Convert xyxy boxes + category ids into the COCO-style dict
        ({"image_id", "annotations": [...]}) consumed by the image processor."""
        coco_annotations = []
        for category, (x1, y1, x2, y2) in zip(categories, boxes):
            box_w = x2 - x1
            box_h = y2 - y1
            coco_annotations.append({
                "image_id": image_id,
                "category_id": category,
                "bbox": [x1, y1, box_w, box_h],
                "iscrowd": 0,
                "area": box_w * box_h,
            })
        return {"image_id": image_id, "annotations": coco_annotations}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        _, image, annotations = self.dataset[idx]

        # Reverse the channel axis (presumably BGR -> RGB, since the images
        # come from supervision's loader — TODO confirm channel order).
        image = image[:, :, ::-1]
        boxes = annotations.xyxy
        categories = annotations.class_id

        if self.transform:
            transformed = self.transform(
                image=image, bboxes=boxes, category=categories)
            image = transformed["image"]
            boxes = transformed["bboxes"]
            categories = transformed["category"]

        coco = self.annotations_as_coco(
            image_id=idx, categories=categories, boxes=boxes)
        encoded = self.processor(
            images=image, annotations=coco, return_tensors="pt")

        # The processor adds a batch dimension; drop it so collate_fn can stack.
        return {key: value[0] for key, value in encoded.items()}
In [15]:
import random

# Build the PyTorch-facing datasets.
#
# NOTE: the original cell constructed ten PyTorchDetectionDataset copies with a
# random.seed(...) call between each construction. Those interleaved seeds had
# no effect: augmentations run lazily inside __getitem__, not at construction
# time, so all ten dataset objects were identical and only the LAST seed
# influenced the global RNG. A loop keeps the behavior (10 lazily-augmented
# passes over the training set per epoch) without the copy-paste; seeding once
# with the last value preserves the original RNG state at access time.
random.seed(233)

N_TRAIN_REPEATS = 10

pytorch_dataset_trains = [
    PyTorchDetectionDataset(
        ds_train, processor, transform=train_augmentation_and_transform)
    for _ in range(N_TRAIN_REPEATS)
]
pytorch_dataset_train = torch.utils.data.ConcatDataset(pytorch_dataset_trains)

pytorch_dataset_valid = PyTorchDetectionDataset(
    ds_valid, processor, transform=valid_transform)
pytorch_dataset_test = PyTorchDetectionDataset(
    ds_test, processor, transform=valid_transform)

# Spot-check one encoded sample.
pytorch_dataset_train[15]
The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.
Out[15]:
{'pixel_values': tensor([[[-0.5253, -0.5253, -0.5253,  ...,  1.1187,  1.2557,  1.7352],
          [-0.5253, -0.5253, -0.5082,  ...,  1.1872,  1.3070,  1.7694],
          [-0.5253, -0.5082, -0.5082,  ...,  1.2899,  1.3927,  1.8037],
          ...,
          [ 1.4098,  1.3927,  1.3927,  ...,  0.0912,  0.1254,  0.1426],
          [ 1.4098,  1.3927,  1.3927,  ...,  0.0912,  0.1254,  0.1426],
          [ 1.4098,  1.3927,  1.3927,  ...,  0.0912,  0.1254,  0.1426]],
 
         [[-0.6176, -0.6176, -0.6176,  ...,  1.2906,  1.4307,  1.9209],
          [-0.6176, -0.6176, -0.6001,  ...,  1.3606,  1.4832,  1.9559],
          [-0.6176, -0.6001, -0.6001,  ...,  1.4657,  1.5707,  1.9909],
          ...,
          [ 1.5357,  1.5182,  1.5182,  ...,  0.1877,  0.2227,  0.2402],
          [ 1.5357,  1.5182,  1.5182,  ...,  0.1877,  0.2227,  0.2402],
          [ 1.5357,  1.5182,  1.5182,  ...,  0.1877,  0.2227,  0.2402]],
 
         [[-0.3927, -0.3927, -0.3927,  ...,  1.4025,  1.5420,  2.0474],
          [-0.3927, -0.3927, -0.3753,  ...,  1.4548,  1.5942,  2.0997],
          [-0.3927, -0.3753, -0.3753,  ...,  1.5768,  1.6988,  2.1346],
          ...,
          [ 1.6117,  1.5942,  1.5942,  ...,  0.1302,  0.1825,  0.1825],
          [ 1.6117,  1.5942,  1.5942,  ...,  0.1302,  0.1825,  0.1825],
          [ 1.6117,  1.5942,  1.5942,  ...,  0.1302,  0.1825,  0.1825]]]),
 'pixel_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]]),
 'labels': {'size': tensor([640, 640]), 'image_id': tensor([15]), 'class_labels': tensor([2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 1]), 'boxes': tensor([[0.5598, 0.5016, 0.8805, 0.9969],
         [0.8109, 0.8906, 0.3375, 0.2188],
         [0.7598, 0.6297, 0.2242, 0.3094],
         [0.5887, 0.8078, 0.1008, 0.2469],
         [0.4871, 0.5645, 0.1570, 0.2445],
         [0.3824, 0.3246, 0.1789, 0.2305],
         [0.6180, 0.3820, 0.2172, 0.1703],
         [0.5066, 0.1457, 0.1961, 0.2758],
         [0.2641, 0.0891, 0.2625, 0.1781],
         [0.2207, 0.2703, 0.1273, 0.1688],
         [0.9750, 0.3922, 0.0437, 0.2812],
         [0.9184, 0.0797, 0.1633, 0.0750]]), 'area': tensor([359513.0000,  30240.0000,  28413.0000,  10191.0000,  15728.2500,
          16888.7500,  15151.0000,  22150.7500,  19152.0000,   8802.0000,
           5040.0000,   5016.0000]), 'iscrowd': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'orig_size': tensor([640, 640])}}
In [16]:
def collate_fn(batch):
    """Stack images into one batch tensor; keep per-sample label dicts as a list."""
    pixel_values = torch.stack([sample["pixel_values"] for sample in batch])
    labels = [sample["labels"] for sample in batch]
    return {"pixel_values": pixel_values, "labels": labels}
In [17]:
# Class-id <-> class-name lookups used by the model config and the evaluator.
id2label = dict(enumerate(ds_train.classes))
label2id = {name: idx for idx, name in id2label.items()}


@dataclass
class ModelOutput:
    # Minimal container exposing the two fields that
    # post_process_object_detection reads from a model output.
    logits: torch.Tensor
    pred_boxes: torch.Tensor


class MAPEvaluator:
    """Computes mAP/mAR metrics (via torchmetrics) from HF Trainer evaluation
    outputs; usable as the Trainer's `compute_metrics` callable."""

    def __init__(self, image_processor, threshold=0.00, id2label=None):
        self.image_processor = image_processor
        self.threshold = threshold  # score threshold for prediction post-processing
        self.id2label = id2label    # optional id -> class-name map for per-class metric names

    def collect_image_sizes(self, targets):
        """Collect image sizes across the dataset as list of tensors with shape [batch_size, 2]."""
        image_sizes = []
        for batch in targets:
            batch_image_sizes = torch.tensor(np.array([x["size"] for x in batch]))
            image_sizes.append(batch_image_sizes)
        return image_sizes

    def collect_targets(self, targets, image_sizes):
        """Convert normalized cxcywh target boxes to absolute xyxy per image."""
        post_processed_targets = []
        for target_batch, image_size_batch in zip(targets, image_sizes):
            for target, (height, width) in zip(target_batch, image_size_batch):
                boxes = target["boxes"]
                boxes = sv.xcycwh_to_xyxy(boxes)
                boxes = boxes * np.array([width, height, width, height])
                boxes = torch.tensor(boxes)
                labels = torch.tensor(target["class_labels"])
                post_processed_targets.append({"boxes": boxes, "labels": labels})
        return post_processed_targets

    def collect_predictions(self, predictions, image_sizes):
        """Run the image processor's detection post-processing on raw outputs."""
        post_processed_predictions = []
        for batch, target_sizes in zip(predictions, image_sizes):
            # batch indices 1 and 2 hold logits and predicted boxes respectively
            batch_logits, batch_boxes = batch[1], batch[2]
            output = ModelOutput(
                logits=torch.tensor(batch_logits),
                pred_boxes=torch.tensor(batch_boxes),
            )
            post_processed_output = self.image_processor.post_process_object_detection(
                output, threshold=self.threshold, target_sizes=target_sizes
            )
            post_processed_predictions.extend(post_processed_output)
        return post_processed_predictions

    @torch.no_grad()
    def __call__(self, evaluation_results):
        """Compute mAP/mAR (overall and per class) for one evaluation pass."""
        predictions, targets = evaluation_results.predictions, evaluation_results.label_ids

        image_sizes = self.collect_image_sizes(targets)
        post_processed_targets = self.collect_targets(targets, image_sizes)
        post_processed_predictions = self.collect_predictions(predictions, image_sizes)

        evaluator = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
        evaluator.warn_on_many_detections = False
        evaluator.update(post_processed_predictions, post_processed_targets)

        metrics = evaluator.compute()

        # Replace list of per-class metrics with a separate metric per class.
        classes = metrics.pop("classes")
        map_per_class = metrics.pop("map_per_class")
        mar_100_per_class = metrics.pop("mar_100_per_class")
        for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
            # BUG FIX: the original read the module-level `id2label` global here
            # instead of the instance attribute passed to __init__, silently
            # ignoring the constructor argument.
            class_name = (
                self.id2label[class_id.item()]
                if self.id2label is not None
                else class_id.item()
            )
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar

        metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

        return metrics

eval_compute_metrics_fn = MAPEvaluator(image_processor=processor, threshold=0.01, id2label=id2label)
In [18]:
# Re-load the backbone with a classification head sized for our label set.
# ignore_mismatched_sizes lets the COCO-pretrained head (92 classes) be
# dropped and replaced by a freshly initialized one matching id2label.
label_maps = dict(id2label=id2label, label2id=label2id)
model = AutoModelForObjectDetection.from_pretrained(
    CHECKPOINT,
    ignore_mismatched_sizes=True,
    **label_maps,
)
Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DetrForObjectDetection were not initialized from the model checkpoint at facebook/detr-resnet-50 and are newly initialized because the shapes did not match:
- class_labels_classifier.bias: found shape torch.Size([92]) in the checkpoint and torch.Size([7]) in the model instantiated
- class_labels_classifier.weight: found shape torch.Size([92, 256]) in the checkpoint and torch.Size([7, 256]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
In [19]:
# https://datascience.stackexchange.com/questions/103022/warmup-steps-in-deep-learning

# Trainer configuration.
# NOTE(review): `evaluation_strategy` was renamed to `eval_strategy` in newer
# transformers releases — confirm against the pinned version before upgrading.
training_args = TrainingArguments(
    output_dir=f"{dataset.name.replace(' ', '-')}-finetune",
    num_train_epochs=64,
    max_grad_norm=0.01,  # very aggressive gradient clipping
    learning_rate=5e-5,
    # lr_scheduler_type="cosine",
    warmup_ratio=0.06,  # fraction of total steps spent warming up the LR
    per_device_train_batch_size=16,
    # dataloader_num_workers=2,
    metric_for_best_model="eval_map",  # produced by the MAPEvaluator each epoch
    greater_is_better=True,
    load_best_model_at_end=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,  # keep only the two most recent checkpoints on disk
    remove_unused_columns=False,  # required: collate_fn consumes the raw columns
    eval_do_concat_batches=False,  # keep per-batch outputs for the evaluator
    logging_strategy="epoch",
    # logging_steps=500,
    weight_decay=1e-4,
)
In [20]:
# NOTE(review): scratch arithmetic left in the notebook — presumably a
# steps-per-epoch sanity check; TODO remove or explain what it estimates.
5640*16/120
Out[20]:
752.0
In [21]:
# Fine-tune the model.
# NOTE(review): `tokenizer=processor` is the legacy argument name; newer
# transformers versions prefer `processing_class` — confirm against the
# pinned version before upgrading.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=pytorch_dataset_train,
    eval_dataset=pytorch_dataset_valid,
    tokenizer=processor,
    data_collator=collate_fn,
    compute_metrics=eval_compute_metrics_fn,
)

# Long-running cell (~17h on the GPU shown above); checkpoints land in output_dir.
trainer.train()
[29888/29888 16:59:56, Epoch 64/64]
Epoch Training Loss Validation Loss Map Map 50 Map 75 Map Small Map Medium Map Large Mar 1 Mar 10 Mar 100 Mar Small Mar Medium Mar Large Map Crack Mar 100 Crack Map Damp Mar 100 Damp Map Dampness Mar 100 Dampness Map Mold Mar 100 Mold Map Stain Mar 100 Stain
1 4.256100 3.031680 0.000900 0.003100 0.000400 0.000000 0.000100 0.001600 0.004800 0.016200 0.049200 0.000000 0.012800 0.069300 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.004700 0.246000 0.000000 0.000000
2 2.513900 2.576494 0.006400 0.014600 0.004400 0.000300 0.000800 0.008900 0.008800 0.027700 0.058500 0.002500 0.029900 0.077200 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.032100 0.292300 0.000000 0.000000
3 2.362400 2.445590 0.006900 0.017100 0.003800 0.000000 0.004100 0.009900 0.013900 0.032500 0.061200 0.003300 0.040300 0.077300 0.002000 0.004800 0.004000 0.007300 0.000000 0.000000 0.028400 0.293800 0.000000 0.000000
4 2.338100 2.554369 0.013800 0.029400 0.012000 0.000000 0.004200 0.019400 0.016000 0.046600 0.079900 0.000800 0.032400 0.108400 0.000200 0.009500 0.000200 0.014500 0.000000 0.000000 0.050200 0.324500 0.018400 0.051100
5 2.275100 2.430867 0.006900 0.022400 0.003400 0.000100 0.008100 0.012600 0.016200 0.051700 0.087000 0.008300 0.053600 0.110600 0.008500 0.050800 0.000500 0.010900 0.000000 0.000000 0.024600 0.309400 0.001000 0.063800
6 2.188400 2.305122 0.015300 0.037300 0.011400 0.000100 0.011900 0.019800 0.027600 0.078500 0.119000 0.001700 0.060200 0.156600 0.017800 0.134900 0.000500 0.049100 0.000000 0.000000 0.056600 0.328100 0.001700 0.083000
7 2.160900 2.215753 0.023700 0.056400 0.014100 0.001000 0.014700 0.034100 0.028700 0.076900 0.113400 0.003300 0.060900 0.148700 0.041300 0.139700 0.001900 0.036400 0.000000 0.000000 0.062000 0.335600 0.013300 0.055300
8 2.099000 2.173176 0.014100 0.038000 0.008700 0.001200 0.009800 0.021600 0.022500 0.082000 0.121100 0.012500 0.086600 0.149400 0.014200 0.176200 0.005300 0.049100 0.000000 0.000000 0.048800 0.339600 0.002200 0.040400
9 2.072800 2.191002 0.020500 0.048400 0.014500 0.000200 0.011200 0.032900 0.043400 0.110400 0.157600 0.007500 0.081900 0.208600 0.027800 0.290500 0.014400 0.105500 0.000000 0.000000 0.055100 0.332500 0.005100 0.059600
10 2.016600 2.104266 0.031400 0.073900 0.022500 0.001200 0.009200 0.045600 0.051500 0.137800 0.197200 0.011700 0.107900 0.259000 0.051300 0.317500 0.015600 0.141800 0.000000 0.000000 0.081800 0.352200 0.008000 0.174500
11 1.991400 2.143938 0.028300 0.073600 0.017300 0.000500 0.010800 0.043900 0.045100 0.110400 0.179500 0.011700 0.087200 0.238500 0.049100 0.290500 0.016400 0.172700 0.000000 0.000000 0.067200 0.328100 0.009100 0.106400
12 1.967800 2.159175 0.025600 0.061100 0.019000 0.000400 0.013400 0.039500 0.058700 0.121200 0.175600 0.007500 0.069700 0.239500 0.051500 0.285700 0.010500 0.145500 0.000000 0.000000 0.046100 0.340400 0.019800 0.106400
13 1.926400 2.185529 0.024000 0.059600 0.014800 0.003300 0.008600 0.040800 0.060200 0.131200 0.196300 0.013300 0.101300 0.258600 0.042200 0.306300 0.008000 0.174500 0.000000 0.000000 0.042200 0.328200 0.027700 0.172300
14 1.914200 2.136886 0.032200 0.068500 0.025500 0.000900 0.009600 0.050100 0.062600 0.140300 0.205100 0.005800 0.089900 0.277500 0.049100 0.390500 0.011200 0.120000 0.000000 0.000000 0.070900 0.353200 0.030000 0.161700
15 1.869200 2.090732 0.035300 0.084500 0.029400 0.001400 0.016000 0.050900 0.062600 0.143200 0.210200 0.005800 0.089200 0.286800 0.058300 0.368300 0.021800 0.121800 0.000000 0.000000 0.074700 0.347900 0.021400 0.212800
16 1.841000 2.066587 0.044600 0.094200 0.040100 0.002200 0.014600 0.064800 0.058300 0.139400 0.199000 0.005000 0.080800 0.272300 0.084600 0.369800 0.014300 0.089100 0.000000 0.000000 0.078200 0.344700 0.046100 0.191500
17 1.792400 2.016057 0.042200 0.097800 0.029900 0.000100 0.022500 0.064800 0.062800 0.154800 0.227900 0.003300 0.119300 0.302700 0.087300 0.414300 0.022900 0.109100 0.000000 0.000000 0.059000 0.356300 0.042000 0.259600
18 1.767400 2.113763 0.039200 0.086900 0.031900 0.000500 0.021000 0.062100 0.059600 0.135100 0.191600 0.005800 0.081300 0.262000 0.112700 0.317500 0.015900 0.167300 0.000000 0.000000 0.060100 0.364500 0.007300 0.108500
19 1.753400 2.056695 0.046800 0.099200 0.041000 0.001200 0.012100 0.076200 0.072700 0.156500 0.222600 0.010000 0.105500 0.297900 0.102800 0.354000 0.014800 0.201800 0.000000 0.000000 0.067800 0.359100 0.048400 0.197900
20 1.741100 1.968403 0.048700 0.105100 0.036000 0.002500 0.015700 0.076600 0.064700 0.160100 0.237000 0.010000 0.120900 0.313200 0.118200 0.371400 0.007900 0.190900 0.000000 0.000000 0.085900 0.373900 0.031700 0.248900
21 1.719100 2.001638 0.061800 0.126000 0.055100 0.000900 0.018900 0.093000 0.071500 0.158000 0.241600 0.005000 0.100200 0.329500 0.133700 0.371400 0.020200 0.161800 0.000000 0.000000 0.067000 0.366200 0.087900 0.308500
22 1.692000 2.061321 0.048900 0.117200 0.037400 0.002500 0.024100 0.074200 0.074000 0.173100 0.236800 0.010800 0.121700 0.315500 0.129100 0.374600 0.008900 0.210900 0.000000 0.000000 0.063200 0.362100 0.043400 0.236200
23 1.664800 2.042977 0.057100 0.124600 0.045000 0.001300 0.034100 0.082500 0.075800 0.174200 0.228400 0.006700 0.113600 0.306300 0.123600 0.363500 0.033600 0.174500 0.000000 0.000000 0.087300 0.353000 0.041200 0.251100
24 1.651000 1.969023 0.055500 0.131100 0.039500 0.005200 0.024400 0.081200 0.067500 0.179500 0.244700 0.011700 0.136700 0.320500 0.137200 0.384100 0.015800 0.201800 0.000000 0.000000 0.085300 0.358600 0.039200 0.278700
25 1.606500 1.960194 0.070100 0.144100 0.058200 0.003100 0.016300 0.107000 0.083300 0.188200 0.241300 0.005800 0.093700 0.329900 0.160900 0.336500 0.021400 0.240000 0.000000 0.000000 0.084500 0.376700 0.083900 0.253200
26 1.579200 2.005141 0.071400 0.150500 0.064700 0.007500 0.021300 0.105000 0.083400 0.176400 0.251500 0.030800 0.106700 0.339000 0.157500 0.387300 0.022800 0.160000 0.000000 0.000000 0.091100 0.376000 0.085400 0.334000
27 1.576500 1.962995 0.068500 0.141200 0.054600 0.007200 0.033300 0.103100 0.080300 0.186200 0.262900 0.040000 0.159700 0.339100 0.134600 0.425400 0.032600 0.225500 0.000000 0.000000 0.086100 0.376200 0.089100 0.287200
28 1.550400 1.913630 0.077600 0.167100 0.068300 0.004700 0.034500 0.112100 0.091300 0.199400 0.271000 0.009200 0.144900 0.356800 0.166200 0.430200 0.039100 0.229100 0.000000 0.000000 0.082600 0.376500 0.099900 0.319100
29 1.527700 1.921490 0.075900 0.156000 0.068900 0.001100 0.035000 0.110700 0.085700 0.196100 0.265900 0.008300 0.134500 0.354200 0.177600 0.390500 0.028200 0.249100 0.000000 0.000000 0.091200 0.385700 0.082300 0.304300
30 1.519100 1.935570 0.068200 0.149300 0.062700 0.012300 0.031400 0.106000 0.093700 0.187000 0.262000 0.054200 0.139200 0.344200 0.112100 0.417500 0.033000 0.274500 0.000000 0.000000 0.078400 0.384100 0.117600 0.234000
31 1.502800 1.930545 0.084700 0.175000 0.077500 0.004200 0.041500 0.118700 0.099100 0.206100 0.269600 0.010000 0.169200 0.347200 0.161900 0.412700 0.057700 0.256400 0.000000 0.000000 0.089500 0.385400 0.114600 0.293600
32 1.470000 1.953679 0.081200 0.155500 0.082000 0.005600 0.024100 0.119300 0.105300 0.192400 0.244400 0.052500 0.105000 0.329200 0.152500 0.373000 0.047500 0.214500 0.000000 0.000000 0.099300 0.393900 0.106800 0.240400
33 1.465000 1.934203 0.064300 0.138400 0.056900 0.003400 0.027000 0.097000 0.074200 0.197100 0.266700 0.006700 0.134900 0.354800 0.160200 0.381000 0.026300 0.330900 0.000000 0.000000 0.082500 0.374900 0.052600 0.246800
34 1.442100 1.925600 0.074000 0.157100 0.069000 0.002500 0.027500 0.116400 0.097800 0.191200 0.264500 0.012500 0.147300 0.347200 0.146400 0.425400 0.036400 0.254500 0.000000 0.000000 0.087600 0.387000 0.099400 0.255300
35 1.425400 1.934578 0.084200 0.167000 0.079800 0.007400 0.035000 0.124600 0.095400 0.199100 0.265300 0.012500 0.134000 0.351400 0.177300 0.427000 0.048000 0.261800 0.000000 0.000000 0.091600 0.390600 0.104300 0.246800
36 1.402200 1.934678 0.078700 0.172000 0.071500 0.002800 0.036600 0.116000 0.087400 0.202900 0.284300 0.009200 0.143600 0.379100 0.127200 0.460300 0.045900 0.294500 0.000000 0.000000 0.098500 0.383900 0.122200 0.283000
37 1.405300 1.956452 0.075300 0.160700 0.065100 0.007800 0.030100 0.110800 0.090600 0.197400 0.275700 0.024200 0.123700 0.371500 0.131200 0.412700 0.034600 0.243600 0.000000 0.000000 0.090600 0.386200 0.120000 0.336200
38 1.370700 1.902062 0.082100 0.167400 0.077100 0.003100 0.033600 0.123700 0.092700 0.202200 0.264500 0.012500 0.134800 0.351600 0.161400 0.449200 0.060500 0.241800 0.000000 0.000000 0.084400 0.384700 0.104100 0.246800
39 1.355500 1.942167 0.082200 0.154400 0.078500 0.003500 0.036000 0.127500 0.099500 0.204000 0.275900 0.010800 0.145300 0.365300 0.176300 0.407900 0.049100 0.285500 0.000000 0.000000 0.093800 0.392300 0.091900 0.293600
40 1.339500 1.934280 0.090300 0.175400 0.092500 0.002600 0.035700 0.136400 0.111100 0.218200 0.283000 0.008300 0.128300 0.382800 0.182800 0.431700 0.056400 0.312700 0.000000 0.000000 0.091400 0.387500 0.121000 0.283000
41 1.319200 1.933329 0.075200 0.150200 0.069800 0.006600 0.029800 0.122100 0.095000 0.221700 0.286300 0.010800 0.126400 0.385100 0.194500 0.363500 0.055300 0.392700 0.000000 0.000000 0.078000 0.383900 0.048500 0.291500
42 1.302000 1.939811 0.090300 0.174300 0.091200 0.007300 0.043200 0.132500 0.099200 0.217100 0.280100 0.020800 0.144700 0.367900 0.205800 0.398400 0.066000 0.349100 0.000000 0.000000 0.098200 0.391100 0.081400 0.261700
43 1.283900 1.949140 0.088500 0.172300 0.082500 0.015000 0.043900 0.128500 0.112500 0.216700 0.280600 0.027500 0.146700 0.368900 0.189400 0.396800 0.063800 0.334500 0.000000 0.000000 0.094600 0.392900 0.094500 0.278700
44 1.267800 1.946965 0.094100 0.187400 0.089100 0.024000 0.029700 0.139900 0.115800 0.230600 0.300800 0.027500 0.137700 0.402400 0.176200 0.449200 0.083400 0.385500 0.000000 0.000000 0.093100 0.377700 0.117500 0.291500
45 1.251200 1.947085 0.092300 0.186500 0.075400 0.016400 0.035000 0.141500 0.119600 0.228600 0.288600 0.032500 0.134300 0.385200 0.181800 0.366700 0.091600 0.370900 0.000000 0.000000 0.095600 0.390600 0.092500 0.314900
46 1.234900 1.921645 0.098400 0.207700 0.086100 0.019200 0.052200 0.144300 0.117800 0.242200 0.309900 0.036700 0.149200 0.410800 0.177900 0.436500 0.084500 0.383600 0.000000 0.000000 0.090800 0.386900 0.138900 0.342600
47 1.221800 1.950424 0.084500 0.172200 0.070900 0.014300 0.039600 0.126100 0.115200 0.225900 0.303500 0.040000 0.161400 0.396700 0.173100 0.439700 0.069600 0.385500 0.000000 0.000000 0.079900 0.383900 0.099900 0.308500
48 1.211700 1.940542 0.098900 0.187100 0.089300 0.017400 0.046300 0.141800 0.114600 0.243000 0.304000 0.032500 0.150800 0.401800 0.189400 0.452400 0.073500 0.358200 0.000000 0.000000 0.096900 0.392600 0.135000 0.317000
49 1.182200 1.942539 0.100500 0.194200 0.092700 0.012300 0.042900 0.150900 0.124900 0.239500 0.307700 0.042500 0.145500 0.408900 0.181900 0.419000 0.085400 0.369100 0.000000 0.000000 0.089700 0.384600 0.145200 0.366000
50 1.183800 1.982745 0.113300 0.208600 0.104900 0.026800 0.037500 0.166100 0.137000 0.236000 0.296100 0.032500 0.126600 0.397400 0.203800 0.404800 0.108100 0.336400 0.000000 0.000000 0.093600 0.390300 0.161000 0.348900
51 1.170200 1.969754 0.093800 0.191100 0.089600 0.032900 0.035900 0.143100 0.118300 0.235400 0.295300 0.042500 0.134600 0.394200 0.169100 0.390500 0.086100 0.356400 0.000000 0.000000 0.082200 0.397500 0.131700 0.331900
52 1.157500 1.952597 0.104300 0.213600 0.091500 0.017600 0.042300 0.151200 0.123200 0.239700 0.296700 0.043300 0.138100 0.393800 0.181400 0.420600 0.083800 0.343600 0.000000 0.000000 0.095700 0.385100 0.160900 0.334000
53 1.135800 1.964583 0.099400 0.198400 0.093000 0.015600 0.040800 0.144300 0.117700 0.235100 0.293500 0.045000 0.133500 0.389700 0.183000 0.407900 0.082600 0.336400 0.000000 0.000000 0.088600 0.380600 0.142900 0.342600
54 1.130700 1.940963 0.100200 0.207800 0.092500 0.023800 0.042500 0.147300 0.119200 0.242700 0.297800 0.030800 0.136800 0.396800 0.190000 0.409500 0.086700 0.343600 0.000000 0.000000 0.085600 0.382800 0.138800 0.353200
55 1.116500 1.931453 0.103500 0.194800 0.102600 0.027400 0.035000 0.152000 0.114500 0.234100 0.292100 0.037500 0.134500 0.388400 0.190700 0.384100 0.103000 0.376400 0.000000 0.000000 0.087300 0.389200 0.136500 0.310600
56 1.108500 1.983815 0.102500 0.201100 0.086000 0.007100 0.044500 0.150400 0.114700 0.229900 0.283100 0.022500 0.132800 0.377700 0.178000 0.352400 0.095800 0.361800 0.000000 0.000000 0.086800 0.375500 0.151600 0.325500
57 1.089700 1.951452 0.105900 0.205800 0.100800 0.010800 0.035800 0.154700 0.115300 0.233500 0.292200 0.042500 0.138400 0.386400 0.199800 0.398400 0.099200 0.385500 0.000000 0.000000 0.084100 0.383400 0.146500 0.293600
58 1.086300 1.973530 0.104400 0.215800 0.090900 0.010200 0.044500 0.152900 0.120800 0.237400 0.295200 0.026700 0.149900 0.388800 0.187100 0.387300 0.111800 0.374500 0.000000 0.000000 0.086900 0.380000 0.136000 0.334000
59 1.075100 1.962855 0.103200 0.208800 0.096000 0.004600 0.043800 0.152900 0.121700 0.242900 0.297300 0.016700 0.142200 0.394300 0.189200 0.387300 0.107800 0.387300 0.000000 0.000000 0.088100 0.386500 0.130700 0.325500
60 1.072100 1.946280 0.101600 0.211100 0.088700 0.009400 0.048200 0.151600 0.121400 0.241300 0.296400 0.025000 0.156100 0.388100 0.179700 0.400000 0.103800 0.380000 0.000000 0.000000 0.085500 0.386900 0.139000 0.314900
61 1.059900 1.966664 0.105500 0.213000 0.102900 0.008200 0.044200 0.154500 0.122800 0.239700 0.296800 0.022500 0.149500 0.392300 0.199200 0.398400 0.106900 0.387300 0.000000 0.000000 0.088400 0.387500 0.133100 0.310600
62 1.055500 1.963952 0.109200 0.217700 0.106200 0.011900 0.044000 0.161500 0.128700 0.242500 0.295800 0.031700 0.142000 0.392200 0.199600 0.377800 0.112900 0.381800 0.000000 0.000000 0.085900 0.387400 0.147600 0.331900
63 1.055300 1.977658 0.109600 0.217600 0.107100 0.011100 0.046300 0.159300 0.126600 0.242800 0.295700 0.026700 0.138000 0.394000 0.199600 0.376200 0.114300 0.385500 0.000000 0.000000 0.085200 0.384900 0.148900 0.331900
64 1.051000 1.974970 0.109100 0.218400 0.106200 0.011300 0.048300 0.158800 0.125800 0.244200 0.298400 0.026700 0.148300 0.394400 0.202300 0.390500 0.108400 0.380000 0.000000 0.000000 0.086100 0.385400 0.148800 0.336200

Out[21]:
TrainOutput(global_step=29888, training_loss=1.5771031553280686, metrics={'train_runtime': 61199.7935, 'train_samples_per_second': 7.812, 'train_steps_per_second': 0.488, 'total_flos': 1.4619877967344435e+20, 'train_loss': 1.5771031553280686, 'epoch': 64.0})
In [22]:
# pip install pycocotools
In [23]:
# pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
In [24]:
# pip install torch==2.3.0+cu121
# pip install torch==0.18.0+cu121
In [ ]:
 
In [25]:
from pprint import pprint

# metrics = trainer.evaluate(eval_dataset=pytorch_dataset_train, metric_key_prefix="train")
# pprint(metrics)
In [26]:
# Run COCO-style evaluation on the validation split; keys are prefixed "eval_"
# (loss, mAP/mAR overall, per-size, and per-class — see printed dict below).
metrics = trainer.evaluate(eval_dataset=pytorch_dataset_valid, metric_key_prefix="eval")
pprint(metrics)
[27/27 00:21]
{'epoch': 64.0,
 'eval_loss': 1.9827451705932617,
 'eval_map': 0.1133,
 'eval_map_50': 0.2086,
 'eval_map_75': 0.1049,
 'eval_map_crack': 0.2038,
 'eval_map_damp': 0.1081,
 'eval_map_dampness': 0.0,
 'eval_map_large': 0.1661,
 'eval_map_medium': 0.0375,
 'eval_map_mold': 0.0936,
 'eval_map_small': 0.0268,
 'eval_map_stain': 0.161,
 'eval_mar_1': 0.137,
 'eval_mar_10': 0.236,
 'eval_mar_100': 0.2961,
 'eval_mar_100_crack': 0.4048,
 'eval_mar_100_damp': 0.3364,
 'eval_mar_100_dampness': 0.0,
 'eval_mar_100_mold': 0.3903,
 'eval_mar_100_stain': 0.3489,
 'eval_mar_large': 0.3974,
 'eval_mar_medium': 0.1266,
 'eval_mar_small': 0.0325,
 'eval_runtime': 13.6742,
 'eval_samples_per_second': 15.723,
 'eval_steps_per_second': 1.975}
In [27]:
from pprint import pprint  # NOTE(review): already imported in an earlier cell; redundant but harmless

# Run COCO-style evaluation on the held-out test split; keys are prefixed "test_".
# Per-class values of -1.0 below mean the class had no ground-truth instances
# in this split — presumably "dampness" is absent from the test set; verify.
metrics = trainer.evaluate(eval_dataset=pytorch_dataset_test, metric_key_prefix="test")
pprint(metrics)
{'epoch': 64.0,
 'test_loss': 1.898290753364563,
 'test_map': 0.1441,
 'test_map_50': 0.2736,
 'test_map_75': 0.1354,
 'test_map_crack': 0.1373,
 'test_map_damp': 0.1711,
 'test_map_dampness': -1.0,
 'test_map_large': 0.1949,
 'test_map_medium': 0.071,
 'test_map_mold': 0.1043,
 'test_map_small': 0.1369,
 'test_map_stain': 0.1638,
 'test_mar_1': 0.1323,
 'test_mar_10': 0.2867,
 'test_mar_100': 0.3558,
 'test_mar_100_crack': 0.3871,
 'test_mar_100_damp': 0.3586,
 'test_mar_100_dampness': -1.0,
 'test_mar_100_mold': 0.3891,
 'test_mar_100_stain': 0.2885,
 'test_mar_large': 0.4502,
 'test_mar_medium': 0.1843,
 'test_mar_small': 0.15,
 'test_runtime': 11.1986,
 'test_samples_per_second': 16.341,
 'test_steps_per_second': 2.054}

Predictions for validation¶

In [28]:
# @title Collect predictions

# @title Collect predictions
#
# Build parallel lists of ground-truth annotations and model detections for
# every image in the validation split, for use by sv.MeanAveragePrecision.

targets = []
predictions = []

for i in range(len(ds_valid)):
    path, source_image, annotations = ds_valid[i]

    # Context manager closes the underlying file handle immediately; the
    # original `Image.open(path)` leaked one descriptor per image across
    # the whole dataset loop.
    with Image.open(path) as image:
        inputs = processor(image, return_tensors="pt").to(DEVICE)
        w, h = image.size

    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # post_process_object_detection expects target sizes as (height, width).
    results = processor.post_process_object_detection(
        outputs, target_sizes=[(h, w)], threshold=0.3)

    detections = sv.Detections.from_transformers(results[0])

    targets.append(annotations)
    predictions.append(detections)
In [29]:
# @title Calculate mAP
# @title Calculate mAP
#
# Aggregate the per-image detections vs. ground truth collected above into
# COCO-style mean-average-precision numbers.
mean_average_precision = sv.MeanAveragePrecision.from_detections(
    targets=targets,
    predictions=predictions,
)

# Report mAP averaged over IoU 0.50:0.95, plus the single-threshold variants.
for metric_name, metric_value in [
    ("map50_95", mean_average_precision.map50_95),
    ("map50", mean_average_precision.map50),
    ("map75", mean_average_precision.map75),
]:
    print(f"{metric_name}: {metric_value:.2f}")
map50_95: 0.13
map50: 0.22
map75: 0.14

Predictions for test¶

In [30]:
# @title Collect predictions

# @title Collect predictions
#
# Build parallel lists of ground-truth annotations and model detections for
# every image in the test split, for use by sv.MeanAveragePrecision.

targets = []
predictions = []

for i in range(len(ds_test)):
    path, source_image, annotations = ds_test[i]

    # Context manager closes the underlying file handle immediately; the
    # original `Image.open(path)` leaked one descriptor per image across
    # the whole dataset loop.
    with Image.open(path) as image:
        inputs = processor(image, return_tensors="pt").to(DEVICE)
        w, h = image.size

    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # post_process_object_detection expects target sizes as (height, width).
    results = processor.post_process_object_detection(
        outputs, target_sizes=[(h, w)], threshold=0.3)

    detections = sv.Detections.from_transformers(results[0])

    targets.append(annotations)
    predictions.append(detections)
In [31]:
# @title Calculate mAP
# @title Calculate mAP
#
# Aggregate the per-image detections vs. ground truth collected above into
# COCO-style mean-average-precision numbers.
mean_average_precision = sv.MeanAveragePrecision.from_detections(
    targets=targets,
    predictions=predictions,
)

# Report mAP averaged over IoU 0.50:0.95, plus the single-threshold variants.
for metric_name, metric_value in [
    ("map50_95", mean_average_precision.map50_95),
    ("map50", mean_average_precision.map50),
    ("map75", mean_average_precision.map75),
]:
    print(f"{metric_name}: {metric_value:.2f}")
map50_95: 0.16
map50: 0.27
map75: 0.17
In [32]:
# @title Calculate Confusion Matrix
# @title Calculate Confusion Matrix
# Cross-tabulate predicted vs. ground-truth classes for the test split
# (uses the `predictions`/`targets` lists built in the test collection cell).
confusion_matrix = sv.ConfusionMatrix.from_detections(
    predictions=predictions,
    targets=targets,
    classes=ds_test.classes
)

# Assign to `_` to suppress the returned figure's repr; the plot still renders.
_ = confusion_matrix.plot()
No description has been provided for this image
In [33]:
# model.save_pretrained("/content/rt-detr/")
# processor.save_pretrained("/content/rt-detr/")
In [34]:
# Side-by-side visualization: ground truth vs. model predictions on test images.
IMAGE_COUNT = 100  # upper bound on how many test images to render

# Clamp to the dataset size — the original `range(IMAGE_COUNT)` raises
# IndexError whenever the test split holds fewer than IMAGE_COUNT images.
for i in range(min(IMAGE_COUNT, len(ds_test))):
    path, source_image, annotations = ds_test[i]

    # Context manager closes the file handle promptly (the original leaked
    # one open file per iteration).
    with Image.open(path) as image:
        inputs = processor(image, return_tensors="pt").to(DEVICE)
        w, h = image.size

    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # post_process_object_detection expects target sizes as (height, width);
    # NMS with a low IoU threshold prunes heavily overlapping boxes.
    results = processor.post_process_object_detection(
        outputs, target_sizes=[(h, w)], threshold=0.3)

    detections = sv.Detections.from_transformers(results[0]).with_nms(threshold=0.1)

    # NOTE(review): labels come from ds_train.classes while the confusion
    # matrix cell uses ds_test.classes — confirm the two class lists match.
    annotated_images = [
        annotate(source_image, annotations, ds_train.classes),
        annotate(source_image, detections, ds_train.classes)
    ]
    grid = sv.create_tiles(
        annotated_images,
        titles=['ground truth', 'prediction'],
        titles_scale=0.5,
        single_tile_size=(400, 400),
        tile_padding_color=sv.Color.WHITE,
        tile_margin_color=sv.Color.WHITE
    )
    sv.plot_image(grid, size=(6, 6))
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [35]:
# Persist the fine-tuned model weights/config and the image processor config
# side by side in artifacts/, so both can be reloaded with from_pretrained().
trainer.save_model("artifacts/")
processor.save_pretrained("artifacts/")
Out[35]:
['artifacts/preprocessor_config.json']
In [ ]: